/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// SWIG typemaps and declarations for building, compiling, and
// executing XLA computations, wrapping most of what is declared in
// xla_data.h.
//
// The typemaps below implement/assert the following correspondences
// (with elaborations below):
//
//    C++                                  Python
// -------------------------------------+---------------------------------------
//  Span<int64>                        <-  sequence of int
//  vector<int>                        ->  sequence of int
//  Span<LocalOp>                      <-  sequence of LocalOp
//  Literal                            <-> (nested tuple of) numpy ndarray
//  std::vector<Literal>               <-  sequence of (nested tuple of) ndarray
//  Shape                               -> pair holding (dtype, dimensions)
//                                     <-  object duck-typed as xla_client.Shape
//  ProgramShape                       ->  pair of ([arg_shapes], ret_shape)
//  std::vector<Shape>                 <-  sequence of xla_client.Shape objects
//  PrimitiveType                      <-  int
//  Span<pair<int64, int64>>           <-  sequence of int pairs
//  PaddingConfig proto                <-  corresponding Python proto
//  ConvolutionDimensionNumbers proto  <-  corresponding Python proto
//  DotDimensionNumbers proto          <-  corresponding Python proto
//  GatherDimensionNumbers proto       <-  corresponding Python proto
//  ScatterDimensionNumbers proto      <-  corresponding Python proto
//  Span<ReplicaGroup proto>           <-  sequence of ReplicaGroup Python proto
//
// Arrows indicate whether a conversion only ever occurs in one
// direction, or whether it is maintained bidirectionally.
//
// The Python objects corresponding to C++ Literals have the type:
//
//   T = ndarray | (T, ...)
//
// where a terminal numpy ndarray translates to a Literal with a
// non-tuple Shape and an XLA primitive element type corresponding to
// the ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates
// to a tuple-shaped Literal whose tuple components are translated
// recursively. For example, if x is a numpy ndarray in Python, with
// shape (2, 3) and dtype of dtype('float32'), then x translates to a
// Literal with rank 2, dimensions 2 and 3, and XLA primitive type
// F32. Meanwhile,
//
//   (x, (x, x), (x,)),
//
// translates to a tuple-shaped XLA Literal, whose component subshapes
// are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
//
// Shapes output by C++ become Python objects with the type:
//
//   T            = (dtype, S)
//   S            = DIMENSIONS | TUPLE_SHAPES
//   DIMENSIONS   = (int, ...)
//   TUPLE_SHAPES = (T, ...)
//
// In the pair described by the T rule, the terminal dtype determines
// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is
// dtype('O'), numpy's object dtype, the structure represents a tuple
// shape and the expansion of the non-terminal S is
// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type
// and S expands into DIMENSIONS giving dimension sizes. For example:
//
//   (dtype('float32'), (3, 5, 7))
//
// describes a 3x5x7 array of F32s, and
//
//   (dtype('O'), ((dtype('float32'), (2, 3)),
//                 (dtype('float64'), (4, 5))))
//
// describes a tuple shape with two subshapes: the first a 2x3 F32,
// and the other a 4x5 F64.
//
// The Python int corresponding to a PrimitiveType enum must be valid
// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32).
//
// The SWIG object wrappers generated by this file are not intended
// for end use, but rather for internal use in the Python XLA client,
// xla_client.py.
//
// One central reason for the Python-side indirection is that the
// Python-side objects produced by the typemaps in this file are
// further packaged up by xla_client before being passed on. For
// instance, the Python pair produced for a C++ Shape is further
// wrapped in a Python class (xla_client.Shape) so as not to expose
// the raw pair externally.
//
// Other SWIG object wrappers (e.g. of Computation) are further
// wrapped by xla_client in order to set up a custom destructor that
// triggers memory deallocation on the C++ side.

%module(threads="1") xla_data

// Keep the GIL except where explicitly specified.
%nothread;

%include "tensorflow/python/platform/base.i"

%{
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/python/numpy_bridge.h"

using namespace xla;
using namespace xla::swig;

%}

// Basic types


// Converts a returned std::vector<int> into a new Python list of ints.
// If list or item allocation fails, the partially built list is released
// and the Python error set by the failing CPython call is propagated.
%typemap(out) std::vector<int> {
  const Py_ssize_t num_items = static_cast<Py_ssize_t>($1.size());
  PyObject* out = PyList_New(num_items);
  if (!out) {
    SWIG_fail;
  }
  for (Py_ssize_t i = 0; i < num_items; ++i) {
    PyObject* item = PyInt_FromLong($1[i]);
    if (!item) {
      // Unfilled list slots are NULL, which list deallocation tolerates.
      Py_DECREF(out);
      SWIG_fail;
    }
    // PyList_SET_ITEM steals the reference to `item`.
    PyList_SET_ITEM(out, i, item);
  }
  $result = out;
}

// Maps StatusOr<bool> to a Python bool. When the StatusOr holds an error,
// raises RuntimeError carrying the status message instead.
%typemap(out) StatusOr<bool> {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  $result = PyBool_FromLong($1.ConsumeValueOrDie());
}

// Maps StatusOr<string> to a Python string built from the held value, or
// raises RuntimeError with the status message when it holds an error.
// NOTE(review): PyString_FromString stops at the first NUL byte, so a value
// with embedded NULs would be truncated; confirm values are always text.
%typemap(out) StatusOr<string> {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  $result = PyString_FromString($1.ConsumeValueOrDie().c_str());
}

// Maps a Status to Python None on success. A non-OK status raises
// RuntimeError whose message is the status' string form.
%typemap(out) Status {
  if ($1.ok()) {
    Py_INCREF(Py_None);
    $result = Py_None;
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.ToString().c_str());
    SWIG_fail;
  }
}

// Converts a Python sequence of integer-like objects (anything
// numpy::PyNumberToPyInt accepts) into an absl::Span<const int64>.
// The span views `temps`, a wrapper-local vector declared by this typemap.
// Raises TypeError for a non-sequence or a non-integer element; errors
// raised by the CPython sequence/number calls are propagated as-is.
%typemap(in) absl::Span<const int64>
    (std::vector<int64> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size has already set a Python exception.
    SWIG_fail;
  }
  temps.resize(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);  // New reference.
    if (!o) {
      SWIG_fail;
    }
    PyObject* py_int = numpy::PyNumberToPyInt(o);
    if (!py_int) {
      PyErr_SetString(
          PyExc_TypeError,
          "Argument sequence element cannot be converted to int");
      Py_DECREF(o);
      SWIG_fail;
    }
    temps[i] = numpy::PyIntOrPyLongToLong(py_int);
    Py_DECREF(py_int);
    Py_DECREF(o);
    // -1 is also a legitimate value, so only fail if an error is pending.
    if (temps[i] == -1 && PyErr_Occurred()) {
      SWIG_fail;
    }
  }
  $1 = temps;
}

// Literal

// Builds an XLA Literal from a (nested tuple of) numpy ndarray(s) via
// numpy::XlaLiteralFromPyObject. The Literal is owned by the wrapper-local
// `literal_status`, and the wrapped function receives a pointer into it.
// Conversion failures raise RuntimeError with the status message.
%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
  literal_status = numpy::XlaLiteralFromPyObject($input);
  if (literal_status.ok()) {
    $1 = &literal_status.ValueOrDie();
  } else {
    PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
    SWIG_fail;
  }
}

// Converts a returned Literal into its Python form (a nested tuple of numpy
// ndarrays) via numpy::PyObjectFromXlaLiteral. Ownership of the new Python
// object passes to $result. Conversion failures raise RuntimeError.
// NOTE(review): $1 is dereferenced here, presumably because SWIG holds the
// by-value result through a wrapper/pointer; confirm against generated code.
%typemap(out) Literal (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
  obj_status = numpy::PyObjectFromXlaLiteral(*$1);
  if (obj_status.ok()) {
    $result = obj_status.ValueOrDie().release();
  } else {
    PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
    SWIG_fail;
  }
}

// Converts a returned StatusOr<Literal> into its Python form. Either failure
// (an error in the StatusOr itself, or a failed Literal-to-Python
// conversion) raises RuntimeError with the corresponding status message.
%typemap(out) StatusOr<Literal> (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
  if (obj_status.ok()) {
    $result = obj_status.ValueOrDie().release();
  } else {
    PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
    SWIG_fail;
  }
}

// Converts a Python sequence whose elements are each a (nested tuple of)
// numpy ndarray(s) into a std::vector<Literal> held in the wrapper-local
// `temps`. Raises TypeError if the argument is not a sequence, and
// RuntimeError if an element cannot be converted to a Literal.
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size has already set a Python exception.
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);  // New reference.
    if (!o) {
      SWIG_fail;
    }
    StatusOr<Literal> literal_status = numpy::XlaLiteralFromPyObject(o);
    if (!literal_status.ok()) {
      PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
      Py_DECREF(o);
      SWIG_fail;
    }
    temps.push_back(literal_status.ConsumeValueOrDie());
    Py_DECREF(o);
  }
  $1 = &temps;
}

// OpMetadata

// Converts a Python object into an OpMetadata proto via
// numpy::OpMetadataFromPyObject and stores it in the wrapper-local `temp`.
// Conversion failures raise RuntimeError with the status message.
%typemap(in) const OpMetadata& (OpMetadata temp) {
  StatusOr<OpMetadata> metadata_or = numpy::OpMetadataFromPyObject($input);
  if (metadata_or.ok()) {
    temp = std::move(metadata_or).ValueOrDie();
    $1 = &temp;
  } else {
    PyErr_SetString(PyExc_RuntimeError, metadata_or.status().ToString().c_str());
    SWIG_fail;
  }
}

// Shape

// Converts a returned Shape reference into the Python shape-info object
// built by numpy::PyShapeInfoFromXlaShape. Ownership passes to $result.
%typemap(out) const Shape& {
  auto shape_info = numpy::PyShapeInfoFromXlaShape(*$1);
  $result = shape_info.release();
}

// Converts a returned StatusOr<Shape> into the Python shape-info object
// built by numpy::PyShapeInfoFromXlaShape. An error status raises
// RuntimeError with the status message.
%typemap(out) StatusOr<Shape> {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()).release();
}


// Converts a returned StatusOr<ProgramShape> into the Python object built by
// numpy::PyProgramShapeInfoFromXlaProgramShape. An error status raises
// RuntimeError with the status message.
%typemap(out) StatusOr<ProgramShape> {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  $result = numpy::PyProgramShapeInfoFromXlaProgramShape(
      $1.ConsumeValueOrDie()).release();
}


// Converts a Python shape object (duck-typed as xla_client.Shape) into an
// XLA Shape stored in the wrapper-local `temp`. Conversion failures raise
// RuntimeError with the status message.
%typemap(in) const Shape& (Shape temp) {
  StatusOr<Shape> shape_or = numpy::XlaShapeFromPyShape($input);
  if (shape_or.ok()) {
    temp = std::move(shape_or).ValueOrDie();
    $1 = &temp;
  } else {
    PyErr_SetString(PyExc_RuntimeError, shape_or.status().ToString().c_str());
    SWIG_fail;
  }
}

// Converts Python None to an empty absl::optional<Shape>. Any other object
// is converted with numpy::XlaShapeFromPyShape. The result lives in the
// wrapper-local `temp`. Conversion failures raise RuntimeError.
%typemap(in) const absl::optional<Shape>& (
    absl::optional<Shape> temp) {
  if ($input != Py_None) {
    StatusOr<Shape> shape_or = numpy::XlaShapeFromPyShape($input);
    if (!shape_or.ok()) {
      PyErr_SetString(PyExc_RuntimeError, shape_or.status().ToString().c_str());
      SWIG_fail;
    }
    temp = std::move(shape_or).ValueOrDie();
  } else {
    temp = absl::nullopt;
  }
  $1 = &temp;
}

// Converts a returned std::unique_ptr<Shape> into the Python shape-info
// object built by numpy::PyShapeInfoFromXlaShape from the pointed-to Shape.
// NOTE(review): the pointer is dereferenced unconditionally, so a null result
// would be undefined behavior; confirm wrapped functions never return null.
%typemap(out) std::unique_ptr<Shape> {
  auto shape_info = numpy::PyShapeInfoFromXlaShape(*$1);
  $result = shape_info.release();
}

// Converts a Python sequence of shape objects (each duck-typed as
// xla_client.Shape) into a std::vector<Shape> held in the wrapper-local
// `temps`. Raises TypeError for a non-sequence and RuntimeError for an
// element that cannot be converted.
%typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size has already set a Python exception.
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);  // New reference.
    if (!o) {
      SWIG_fail;
    }
    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
    Py_DECREF(o);
    if (!statusor.ok()) {
      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
      SWIG_fail;
    }
    temps.push_back(statusor.ConsumeValueOrDie());
  }
  $1 = &temps;
}

// Converts a Python sequence whose elements are either None or shape objects
// into a std::vector<absl::optional<Shape>> held in the wrapper-local
// `temps`. None elements become absl::nullopt. Raises TypeError for a
// non-sequence and RuntimeError for a non-None element that cannot be
// converted.
%typemap(in) const std::vector<absl::optional<Shape> >& (
    std::vector<absl::optional<Shape> > temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size has already set a Python exception.
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);  // New reference.
    if (!o) {
      SWIG_fail;
    }
    if (o == Py_None) {
      // PySequence_GetItem returned a new reference to None; release it too.
      Py_DECREF(o);
      temps.push_back(absl::nullopt);
      continue;
    }
    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
    Py_DECREF(o);
    if (!statusor.ok()) {
      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
      SWIG_fail;
    }
    temps.push_back(statusor.ConsumeValueOrDie());
  }
  $1 = &temps;
}

// PrimitiveType

// Converts a Python integer-like object into a PrimitiveType enum value.
// Raises TypeError if the object cannot be converted to an int, or if the
// value is not a valid PrimitiveType.
%typemap(in) PrimitiveType {
  PyObject* py_int = numpy::PyNumberToPyInt($input);
  if (!py_int) {
    PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int");
    SWIG_fail;
  }
  const long value = numpy::PyIntOrPyLongToLong(py_int);
  // Only the extracted C value is used below, so release the int object on
  // every path.
  Py_DECREF(py_int);
  if (value == -1 && PyErr_Occurred()) {
    SWIG_fail;
  }
  // Reject values outside int range before PrimitiveType_IsValid, which
  // takes an int. Otherwise a large value could wrap into a valid enum value.
  if (static_cast<int>(value) != value ||
      !PrimitiveType_IsValid(static_cast<int>(value))) {
    PyErr_SetString(
        PyExc_TypeError, "Argument not valid for PrimitiveType enum");
    SWIG_fail;
  }
  $1 = static_cast<PrimitiveType>(value);
}

// Span<pair<int64, int64>>

// Converts a Python sequence of integer pairs, e.g. [(lo, hi), ...], into an
// absl::Span<const std::pair<int64, int64>> that views the wrapper-local
// `temps` vector. Each pair may be any Python sequence (tuple or list). Only
// its first two items are read, and each is converted with
// numpy::PyNumberToPyInt. Raises TypeError for a non-sequence argument or a
// pair item that cannot be converted to int. Other CPython errors (e.g. a
// pair with fewer than two items) are propagated as-is.
%typemap(in) absl::Span<const std::pair<int64, int64> >
    (std::vector<std::pair<int64, int64> > temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size has already set a Python exception.
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);  // New reference.
    if (!o) {
      SWIG_fail;
    }
    int64 pair_values[2];
    bool converted = true;
    for (int j = 0; j < 2; ++j) {
      PyObject* item = PySequence_GetItem(o, j);  // New reference.
      if (!item) {
        converted = false;
        break;
      }
      PyObject* item_pyint = numpy::PyNumberToPyInt(item);
      Py_DECREF(item);
      if (!item_pyint) {
        PyErr_SetString(
            PyExc_TypeError,
            j == 0 ? "First pair item cannot be converted to int"
                   : "Second pair item cannot be converted to int");
        converted = false;
        break;
      }
      pair_values[j] = numpy::PyIntOrPyLongToLong(item_pyint);
      Py_DECREF(item_pyint);
      // -1 is also a legitimate value, so only fail if an error is pending.
      if (pair_values[j] == -1 && PyErr_Occurred()) {
        converted = false;
        break;
      }
    }
    Py_DECREF(o);
    if (!converted) {
      SWIG_fail;
    }
    temps.push_back(std::make_pair(pair_values[0], pair_values[1]));
  }
  $1 = temps;
}

// DotDimensionNumbers

// Builds a DotDimensionNumbers proto from a Python object exposing the
// attributes lhs_contracting_dimensions, rhs_contracting_dimensions,
// lhs_batch_dimensions and rhs_batch_dimensions. The attributes are read in
// that order with HandleRepeatedInt64Attribute, and reading stops at the
// first failure. That helper presumably leaves a Python exception set on
// failure; confirm in its definition.
%typemap(in) const DotDimensionNumbers&
    (DotDimensionNumbers dimension_numbers) {
  if (!HandleRepeatedInt64Attribute(
          $input, "lhs_contracting_dimensions",
          dimension_numbers.mutable_lhs_contracting_dimensions()) ||
      !HandleRepeatedInt64Attribute(
          $input, "rhs_contracting_dimensions",
          dimension_numbers.mutable_rhs_contracting_dimensions()) ||
      !HandleRepeatedInt64Attribute(
          $input, "lhs_batch_dimensions",
          dimension_numbers.mutable_lhs_batch_dimensions()) ||
      !HandleRepeatedInt64Attribute(
          $input, "rhs_batch_dimensions",
          dimension_numbers.mutable_rhs_batch_dimensions())) {
    SWIG_fail;
  }
  $1 = &dimension_numbers;
}

// PaddingConfig

// Builds a PaddingConfig proto from a Python object whose `dimensions`
// attribute is a sequence. Each element must expose the integer attributes
// edge_padding_low, edge_padding_high and interior_padding (read with
// GetIntAttr). One PaddingConfigDimension is appended per element. The
// `dimensions` reference is always released before returning or failing.
%typemap(in) const PaddingConfig&
    (PaddingConfig padding_config) {
  PyObject* py_dims = PyObject_GetAttrString($input, "dimensions");
  if (!py_dims) {
    SWIG_fail;
  }

  const int num_dims = PySequence_Size(py_dims);
  bool failed = (num_dims == -1);
  for (int i = 0; !failed && i < num_dims; ++i) {
    PyObject* py_dim = PySequence_GetItem(py_dims, i);
    if (!py_dim) {
      failed = true;
      break;
    }
    int64 low, high, interior;
    const bool have_all =
        GetIntAttr(py_dim, "edge_padding_low", &low) &&
        GetIntAttr(py_dim, "edge_padding_high", &high) &&
        GetIntAttr(py_dim, "interior_padding", &interior);
    Py_DECREF(py_dim);
    if (!have_all) {
      failed = true;
      break;
    }
    PaddingConfig::PaddingConfigDimension* dim =
        padding_config.add_dimensions();
    dim->set_edge_padding_low(low);
    dim->set_edge_padding_high(high);
    dim->set_interior_padding(interior);
  }
  Py_DECREF(py_dims);
  if (failed) {
    SWIG_fail;
  }

  $1 = &padding_config;
}

// ConvolutionDimensionNumbers

// Builds a ConvolutionDimensionNumbers proto from a Python object.
// Six scalar attributes are read first with GetIntAttr:
// input_batch_dimension, input_feature_dimension, output_batch_dimension,
// output_feature_dimension, kernel_output_feature_dimension and
// kernel_input_feature_dimension. Then three repeated attributes are read
// with HandleRepeatedInt64Attribute: input_spatial_dimensions,
// kernel_spatial_dimensions and output_spatial_dimensions. Reading stops at
// the first failure.
%typemap(in) const ConvolutionDimensionNumbers&
    (ConvolutionDimensionNumbers dimension_numbers) {
  int64 input_batch, input_feature, output_batch, output_feature,
      kernel_output_feature, kernel_input_feature;
  if (!GetIntAttr($input, "input_batch_dimension", &input_batch) ||
      !GetIntAttr($input, "input_feature_dimension", &input_feature) ||
      !GetIntAttr($input, "output_batch_dimension", &output_batch) ||
      !GetIntAttr($input, "output_feature_dimension", &output_feature) ||
      !GetIntAttr($input, "kernel_output_feature_dimension",
                  &kernel_output_feature) ||
      !GetIntAttr($input, "kernel_input_feature_dimension",
                  &kernel_input_feature)) {
    SWIG_fail;
  }
  dimension_numbers.set_input_batch_dimension(input_batch);
  dimension_numbers.set_input_feature_dimension(input_feature);
  dimension_numbers.set_output_batch_dimension(output_batch);
  dimension_numbers.set_output_feature_dimension(output_feature);
  dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature);
  dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature);

  if (!HandleRepeatedInt64Attribute(
          $input, "input_spatial_dimensions",
          dimension_numbers.mutable_input_spatial_dimensions()) ||
      !HandleRepeatedInt64Attribute(
          $input, "kernel_spatial_dimensions",
          dimension_numbers.mutable_kernel_spatial_dimensions()) ||
      !HandleRepeatedInt64Attribute(
          $input, "output_spatial_dimensions",
          dimension_numbers.mutable_output_spatial_dimensions())) {
    SWIG_fail;
  }

  $1 = &dimension_numbers;
}

// GatherDimensionNumbers

// Builds a GatherDimensionNumbers proto from a Python object. The repeated
// attributes offset_dims, collapsed_slice_dims and start_index_map are read
// first with HandleRepeatedInt64Attribute. The scalar index_vector_dim is
// read last with GetIntAttr. Reading stops at the first failure.
%typemap(in) const GatherDimensionNumbers&
    (GatherDimensionNumbers dimension_numbers) {
  int64 index_vector_dim;
  if (!HandleRepeatedInt64Attribute(
          $input, "offset_dims",
          dimension_numbers.mutable_offset_dims()) ||
      !HandleRepeatedInt64Attribute(
          $input, "collapsed_slice_dims",
          dimension_numbers.mutable_collapsed_slice_dims()) ||
      !HandleRepeatedInt64Attribute(
          $input, "start_index_map",
          dimension_numbers.mutable_start_index_map()) ||
      !GetIntAttr($input, "index_vector_dim", &index_vector_dim)) {
    SWIG_fail;
  }
  dimension_numbers.set_index_vector_dim(index_vector_dim);

  $1 = &dimension_numbers;
}

// ScatterDimensionNumbers

// Builds a ScatterDimensionNumbers proto from a Python object. The repeated
// attributes update_window_dims, inserted_window_dims and
// scatter_dims_to_operand_dims are read first with
// HandleRepeatedInt64Attribute. The scalar index_vector_dim is read last with
// GetIntAttr. Reading stops at the first failure.
%typemap(in) const ScatterDimensionNumbers&
    (ScatterDimensionNumbers dimension_numbers) {
  int64 index_vector_dim;
  if (!HandleRepeatedInt64Attribute(
          $input, "update_window_dims",
          dimension_numbers.mutable_update_window_dims()) ||
      !HandleRepeatedInt64Attribute(
          $input, "inserted_window_dims",
          dimension_numbers.mutable_inserted_window_dims()) ||
      !HandleRepeatedInt64Attribute(
          $input, "scatter_dims_to_operand_dims",
          dimension_numbers.mutable_scatter_dims_to_operand_dims()) ||
      !GetIntAttr($input, "index_vector_dim", &index_vector_dim)) {
    SWIG_fail;
  }
  dimension_numbers.set_index_vector_dim(index_vector_dim);

  $1 = &dimension_numbers;
}

// Span<const ReplicaGroup>

// Converts a Python sequence of ReplicaGroup-like objects into an
// absl::Span<const ReplicaGroup> that views the wrapper-local `temps`
// vector. Each element's `replica_ids` attribute is read with
// HandleRepeatedInt64Attribute. Raises TypeError for a non-sequence. Each
// element reference is released on both the success and failure paths.
%typemap(in) absl::Span<const ReplicaGroup >
    (std::vector<ReplicaGroup > temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size has already set a Python exception.
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);  // New reference.
    if (!o) {
      SWIG_fail;
    }
    ReplicaGroup rgrp;
    const bool ok = HandleRepeatedInt64Attribute(
        o, "replica_ids", rgrp.mutable_replica_ids());
    Py_DECREF(o);
    if (!ok) {
      SWIG_fail;
    }
    temps.push_back(std::move(rgrp));
  }
  $1 = temps;
}
